import os

import cv2
import numpy as np
import pandas as pd
from imutils import auto_canny

import utils
from src import choice_question_recognition, my_vgg, exam_num_recognition

# Model returned by my_vgg.get_trained_my_vgg16(). It is created once at import
# time and shared by every function below: it is passed to my_vgg.judge_by_model
# (get_contours, get_answer_recognition) and to
# choice_question_recognition.choice_question_recognition (answer_card_recognition).
my_model = my_vgg.get_trained_my_vgg16()


def init_process_image(image_path: str):
    """
    Load an answer-card image and compute its edge map.

    The image is morphologically opened (5x5 kernel, 15 iterations), converted
    to grayscale, Gaussian-blurred (5x5, sigma 1) and passed to ``auto_canny``.

    :param image_path: path of the image to process
    :return image, edged: the image as read by ``cv2.imread`` and its edge map
    :raises FileNotFoundError: if ``cv2.imread`` cannot read ``image_path``
    """
    image = cv2.imread(image_path)
    # cv2.imread signals failure (missing file, unsupported format) by returning
    # None instead of raising; fail here with a clear message rather than with
    # an opaque cv2.error from morphologyEx below.
    if image is None:
        raise FileNotFoundError("cannot read image: {}".format(image_path))
    kernel = np.ones((5, 5), np.uint8)
    result_image = cv2.morphologyEx(image, cv2.MORPH_OPEN, kernel, iterations=15)
    gray = cv2.cvtColor(result_image, cv2.COLOR_BGR2GRAY)
    blurred = cv2.GaussianBlur(gray, (5, 5), 1)
    edged = auto_canny(blurred)
    return image, edged


def get_contours(image, edged):
    """
    Locate the corner markers of the answer card.

    Every external contour of ``edged`` is cropped out of ``image`` and
    classified with ``my_vgg.judge_by_model``. Contours whose label equals
    ``utils.label_dict['corner']`` are kept. For each kept contour, the centre
    of the bounding box of its polygon approximation is returned.

    :param image: the original answer-card image
    :param edged: edge map of the answer card (as produced by init_process_image)
    :return corners: list of centre points, one per contour classified as a corner
    """
    contours, _ = cv2.findContours(edged.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    corner_label = utils.label_dict['corner']
    corners = []
    for contour in contours:
        temp_image = utils.get_image_by_contour(image, contour)
        # Skip degenerate crops, same as get_answer_recognition does.
        if temp_image.shape[0] * temp_image.shape[1] == 0:
            continue
        index, _ = my_vgg.judge_by_model(temp_image, my_model)
        # Compare labels by value: ``is`` only works by accident for small
        # cached Python ints and is always False for e.g. NumPy integer labels.
        if index != corner_label:
            continue
        peri = cv2.arcLength(contour, True)
        approx = cv2.approxPolyDP(contour, 0.02 * peri, True)
        box = cv2.boundingRect(approx)
        corners.append(utils.get_center_of_box(box))
    return corners


def get_answer_card(image, corners):
    """
    Perspective-warp the answer card to an upright rectangle.

    The output size is taken from the longer of each pair of opposite edges
    of the quadrilateral formed by the ordered corner points.

    :param image: the original answer-card image
    :param corners: coordinates of the card's corner points
    :return warped: the perspective-corrected answer card
    """
    def _edge_length(p, q):
        # Euclidean distance between two points.
        return np.sqrt(((p[0] - q[0]) ** 2) + ((p[1] - q[1]) ** 2))

    # Order the points consistently: top-left, top-right, bottom-right, bottom-left.
    rect = utils.order_points(corners)
    top_left, top_right, bottom_right, bottom_left = rect

    out_w = max(int(_edge_length(bottom_right, bottom_left)),
                int(_edge_length(top_right, top_left)))
    out_h = max(int(_edge_length(top_right, bottom_right)),
                int(_edge_length(top_left, bottom_left)))

    # Destination corners in the same order as ``rect``.
    target = np.array(
        [
            [0, 0],
            [out_w - 1, 0],
            [out_w - 1, out_h - 1],
            [0, out_h - 1],
        ],
        dtype="float32",
    )
    transform = cv2.getPerspectiveTransform(rect, target)
    return cv2.warpPerspective(image, transform, (out_w, out_h))


def get_answer_recognition(answer_card):
    """
    Extract the filled-in regions from a perspective-corrected answer card.

    The card is morphologically opened, converted to grayscale and
    Otsu-thresholded (inverted). Each external contour is then cropped and
    classified with ``my_vgg.judge_by_model``.

    :param answer_card: the cropped, perspective-corrected answer card
    :return choice, person_info: ``choice`` is a list of crops classified as
        ``"choices_question"``. Each crop is widened by 2 px on either side and
        clamped to the card. ``person_info`` is the ``'exam_num'`` crop with the
        highest prediction score, or None if no crop was classified as
        ``'exam_num'``.
    """
    kernel = np.ones((3, 3), np.uint8)
    opened = cv2.morphologyEx(answer_card, cv2.MORPH_OPEN, kernel, iterations=18)
    gray = cv2.cvtColor(opened, cv2.COLOR_BGR2GRAY)
    # With THRESH_OTSU the given threshold (-1) is ignored and computed automatically.
    _, thresh = cv2.threshold(gray, -1, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    choice_label = utils.label_dict["choices_question"]
    exam_num_label = utils.label_dict['exam_num']
    card_width = answer_card.shape[1]

    choice = []
    person_info = None
    info_predict = 0.0
    for c in contours:
        temp_image = utils.get_image_by_contour(answer_card, c)
        if temp_image.shape[0] * temp_image.shape[1] == 0:
            continue
        index, predict = my_vgg.judge_by_model(temp_image, my_model)
        # Labels are compared by value; ``is`` breaks for non-cached / NumPy ints.
        if index == choice_label:
            x, y, w, h = cv2.boundingRect(c)
            # Widen the box by 2 px on each side, but never outside the card:
            # a negative x would make slicing wrap around to the far edge.
            left = max(x - 2, 0)
            right = min(x + w + 2, card_width)
            choice.append(utils.get_image_by_box(answer_card, (left, y, right - left, h)))
        if index == exam_num_label and predict > info_predict:
            person_info = temp_image
            info_predict = predict
    return choice, person_info


def answer_card_recognition(image_path: str):
    """
    Run the full recognition pipeline on one scanned answer card.

    Side effects: prints the detected corner points and writes debug images
    to ``../pic/answer_card.jpg``, ``../pic/exam_num/info.jpg`` (only when an
    exam-number region was found) and ``../pic/choices_question/<n>.jpg``.

    :param image_path: path of the scanned answer-card image
    :return exam_num, answers: ``exam_num`` is the result of
        ``exam_num_recognition``, or "" when no exam-number region was found.
        ``answers`` is a dict built from the pairs returned by
        ``choice_question_recognition``, inserted in order of their first
        element (presumably the question number — confirm against that module).
    """
    image, edged = init_process_image(image_path)
    cv2.imwrite("../pic/answer_card.jpg", image)
    corners = get_contours(image, edged)
    print(corners)
    warped = get_answer_card(image, corners)
    choice, info = get_answer_recognition(warped)

    exam_num = ""
    # get_answer_recognition returns None when no exam-number region was
    # classified; cv2.imwrite would raise on a None image.
    if info is not None:
        cv2.imwrite("../pic/exam_num/info.jpg", info)
        exam_num = exam_num_recognition.exam_num_recognition(info)

    answers = []
    for cnt, c in enumerate(choice):
        cv2.imwrite("../pic/choices_question/{}.jpg".format(cnt), c)
        answers += choice_question_recognition.choice_question_recognition(c, my_model)
    # sorted() is stable, so for duplicate keys the later entry still wins in dict().
    answers = dict(sorted(answers, key=lambda x: x[0]))
    return exam_num, answers


if __name__ == '__main__':
    utils.get_selected_box_size()
    path_dir = "../answer_card"
    names = os.listdir(path_dir)
    # NOTE: debug override — restricts the run to a single card. Remove this
    # line to process every file listed in path_dir.
    names = ["SKM_C36818012910480_0052.jpg"]
    question_count = 50
    col = ["path", "ID"] + ["Q{}".format(i) for i in range(1, question_count + 1)]
    result = []
    for name in names:
        utils.read_cnt()
        path = os.path.join(path_dir, name)
        print(path)
        exam_num, answers = answer_card_recognition(path)
        item = [path, exam_num]
        for key in range(1, question_count + 1):
            # Several marked options for one question are joined with "|";
            # questions with no recognised answer get an empty cell.
            item.append("|".join(answers[key]) if key in answers else "")
        utils.save_cnt()
        print(item)
        result.append(item)
    df = pd.DataFrame(result, columns=col)
    df.to_csv("../result.csv", encoding='utf-8')
